notebooks/Directed Training Clusters.ipynb (359 lines of code) (raw):
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "a450e000-91f7-4a0c-a3b8-c13d449f17fe",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "6f310eb2-14db-47a8-a8a6-e2353d570c28",
"metadata": {},
"source": [
"### This code clusters the training data so that some of it can be re-labeled and used in a multi-shot training flow"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "b339b641-b4dd-4893-9c91-a3e3c59e1e59",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "0067ab2a-96d3-4a50-97cb-a1c307a89a5f",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.cluster import KMeans\n",
"from bertopic import BERTopic"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "79626719-5913-457a-8b54-bef37cf36ff3",
"metadata": {},
"outputs": [],
"source": [
"# Load the fine-tuning examples whose title/keyword columns are clustered further down.\n",
"df = pd.read_csv(\n",
"    \"../test_data/topic_fine_tuning_data__01_05.csv\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a5954814-a3d7-48b4-b822-bd71c0f53351",
"metadata": {},
"outputs": [],
"source": [
"# KMeans starts from randomly chosen centroids; pin the seed so the clustering step\n",
"# itself is deterministic for a given set of embeddings.\n",
"# NOTE(review): BERTopic's default dimensionality-reduction step is presumably stochastic\n",
"# too - confirm, and pass a seeded `umap_model` if fully reproducible topics are required.\n",
"cluster_model = KMeans(n_clusters=100, random_state=42)\n",
"# The KMeans instance is handed to BERTopic through its `hdbscan_model` argument.\n",
"topic_model = BERTopic(hdbscan_model=cluster_model)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "21fefd42-5cdd-4cf0-af64-2698220ad638",
"metadata": {},
"outputs": [],
"source": [
"# Replace missing values with empty strings: the title and keyword columns are later\n",
"# joined with `+`, and a NaN operand would turn the whole document into NaN.\n",
"df = df.fillna(value=\"\")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "0f895d7e-2288-4d4c-8c26-023aaea667cc",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
]
}
],
"source": [
"# One document per row: title text and keyword text joined by a single space.\n",
"# `.tolist()` hands BERTopic a plain list of strings (its documented input type)\n",
"# instead of a pandas Series, so nothing depends on how a Series/index is handled.\n",
"docs = (df[\"input_titles\"] + \" \" + df[\"input_keywords\"]).tolist()\n",
"topics, probs = topic_model.fit_transform(docs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ff70cdf-0de4-4605-96a5-0d273b4c6376",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "657cb274-524f-4cf0-9957-3d222b8bbe17",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 15,
"id": "ab3f9b77-9143-47cf-8c5d-bca997f81b9d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"100"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Number of distinct topic ids that were actually assigned to documents.\n",
"len(set(topics))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "3a749246-9a62-4c06-98a4-7cd36eccaed2",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Topic</th>\n",
" <th>Count</th>\n",
" <th>Name</th>\n",
" <th>Representation</th>\n",
" <th>Representative_Docs</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>88</td>\n",
" <td>0_food_meal_chefs_recipes</td>\n",
" <td>[food, meal, chefs, recipes, culinary, pantry,...</td>\n",
" <td>[Food Photography Group - Flickr\\nYoung Chefs ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>87</td>\n",
" <td>1_bank_banking_online_account</td>\n",
" <td>[bank, banking, online, account, login, mybank...</td>\n",
" <td>[Account Overview - Online Banking\\nTransfer F...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2</td>\n",
" <td>58</td>\n",
" <td>2_vacation_flights_travel_rentals</td>\n",
" <td>[vacation, flights, travel, rentals, airbnb, c...</td>\n",
" <td>[Reykjavik Vacation Rentals & Homes - Airbnb\\n...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>3</td>\n",
" <td>54</td>\n",
" <td>3_ai_artificial_openai_intelligence</td>\n",
" <td>[ai, artificial, openai, intelligence, future,...</td>\n",
" <td>[Understanding AI: Basics for Everyone\\nStream...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>4</td>\n",
" <td>54</td>\n",
" <td>4_climate_disasters_natural_weather</td>\n",
" <td>[climate, disasters, natural, weather, change,...</td>\n",
" <td>[Meteorology Hub: 2023 Climate Predictions\\nFo...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>95</th>\n",
" <td>95</td>\n",
" <td>7</td>\n",
" <td>95_assisted_dying_bill_guardian</td>\n",
" <td>[assisted, dying, bill, guardian, abbott, thir...</td>\n",
" <td>[Newscast - What next for the assisted dying b...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>96</th>\n",
" <td>96</td>\n",
" <td>6</td>\n",
" <td>96_hn_hacker_show_internet</td>\n",
" <td>[hn, hacker, show, internet, use, tatedatabrea...</td>\n",
" <td>[Security Engineer\\nMindshift: Break Through O...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>97</th>\n",
" <td>97</td>\n",
" <td>4</td>\n",
" <td>97_printers_printer_amazons_ink</td>\n",
" <td>[printers, printer, amazons, ink, choice, mult...</td>\n",
" <td>[Amazon's Choice Printer Ink , Amazon's Choice...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>98</th>\n",
" <td>98</td>\n",
" <td>3</td>\n",
" <td>98_yourtango_taught_boomers_eyes</td>\n",
" <td>[yourtango, taught, boomers, eyes, ginkgonotes...</td>\n",
" <td>[GinkgoNotes: Never forget what you've learned...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>99</th>\n",
" <td>99</td>\n",
" <td>1</td>\n",
" <td>99_perils_motherhood_momfluencers_tradwives</td>\n",
" <td>[perils, motherhood, momfluencers, tradwives, ...</td>\n",
" <td>[Momfluencers, tradwives, and the perils of mo...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>100 rows × 5 columns</p>\n",
"</div>"
],
"text/plain": [
" Topic Count Name \\\n",
"0 0 88 0_food_meal_chefs_recipes \n",
"1 1 87 1_bank_banking_online_account \n",
"2 2 58 2_vacation_flights_travel_rentals \n",
"3 3 54 3_ai_artificial_openai_intelligence \n",
"4 4 54 4_climate_disasters_natural_weather \n",
".. ... ... ... \n",
"95 95 7 95_assisted_dying_bill_guardian \n",
"96 96 6 96_hn_hacker_show_internet \n",
"97 97 4 97_printers_printer_amazons_ink \n",
"98 98 3 98_yourtango_taught_boomers_eyes \n",
"99 99 1 99_perils_motherhood_momfluencers_tradwives \n",
"\n",
" Representation \\\n",
"0 [food, meal, chefs, recipes, culinary, pantry,... \n",
"1 [bank, banking, online, account, login, mybank... \n",
"2 [vacation, flights, travel, rentals, airbnb, c... \n",
"3 [ai, artificial, openai, intelligence, future,... \n",
"4 [climate, disasters, natural, weather, change,... \n",
".. ... \n",
"95 [assisted, dying, bill, guardian, abbott, thir... \n",
"96 [hn, hacker, show, internet, use, tatedatabrea... \n",
"97 [printers, printer, amazons, ink, choice, mult... \n",
"98 [yourtango, taught, boomers, eyes, ginkgonotes... \n",
"99 [perils, motherhood, momfluencers, tradwives, ... \n",
"\n",
" Representative_Docs \n",
"0 [Food Photography Group - Flickr\\nYoung Chefs ... \n",
"1 [Account Overview - Online Banking\\nTransfer F... \n",
"2 [Reykjavik Vacation Rentals & Homes - Airbnb\\n... \n",
"3 [Understanding AI: Basics for Everyone\\nStream... \n",
"4 [Meteorology Hub: 2023 Climate Predictions\\nFo... \n",
".. ... \n",
"95 [Newscast - What next for the assisted dying b... \n",
"96 [Security Engineer\\nMindshift: Break Through O... \n",
"97 [Amazon's Choice Printer Ink , Amazon's Choice... \n",
"98 [GinkgoNotes: Never forget what you've learned... \n",
"99 [Momfluencers, tradwives, and the perils of mo... \n",
"\n",
"[100 rows x 5 columns]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Per-topic summary table: topic id, document count, generated name,\n",
"# top terms and representative documents.\n",
"topic_model.get_topic_info()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "494d20a4-d2e3-4d51-aee2-f5d9ae8db96e",
"metadata": {},
"outputs": [],
"source": [
"# Attach each row's topic id as a new `assigned_topic` column.\n",
"df = df.assign(assigned_topic=topics)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "8f263907-38f9-4e17-842f-a5943ced8777",
"metadata": {},
"outputs": [],
"source": [
"# Persist the topic-labelled rows alongside the input file.\n",
"# `index=True` is the pandas default spelled out: the row index is written as the\n",
"# first, unnamed column. NOTE(review): switch to index=False if downstream readers\n",
"# do not expect that column - confirm before changing.\n",
"df.to_csv(\n",
"    \"../test_data/topic_fine_tuning_data__01_05__grouped.csv\",\n",
"    index=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "418c7c48-8d79-4bc5-a194-fa2c33d563fa",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}